Code for "Grassmann Stein Variational Gradient Descent" (AISTATS 2022)

ImperialCollegeLondon/GSVGD

GSVGD

[Paper][Slides][Poster]

Data

The Covertype dataset is downloaded from the UCI Machine Learning Repository: https://archive.ics.uci.edu/ml/datasets/covertype

Other Dependencies

  • Code for Sliced-SVGD is adapted from Wenbo Gong's repo
  • Code for optimization on Grassmann manifold is adapted from Pymanopt
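Optimization on the Grassmann manifold rests on two basic operations: projecting a Euclidean gradient onto the tangent space at the current subspace, and retracting the updated point back onto the manifold. The sketch below shows both in NumPy with a QR-based retraction. It is a generic illustration, not the code in src/manifold.py or Pymanopt's implementation, and the function names `tangent_project` and `qr_retract` are hypothetical.

```python
import numpy as np

def tangent_project(A, G):
    """Project a Euclidean gradient G (d x k) onto the tangent space of the
    Grassmann manifold at the subspace spanned by the orthonormal basis A (d x k)."""
    return G - A @ (A.T @ G)

def qr_retract(A, xi):
    """Map A + xi back onto the manifold by re-orthonormalising with a QR step."""
    Q, R = np.linalg.qr(A + xi)
    # Fix column signs so that diag(R) > 0, which makes the retraction continuous.
    signs = np.sign(np.diag(R))
    signs[signs == 0] = 1.0
    return Q * signs

rng = np.random.default_rng(0)
d, k = 5, 2
A, _ = np.linalg.qr(rng.standard_normal((d, k)))  # random point on Gr(k, d)
G = rng.standard_normal((d, k))                   # stand-in Euclidean gradient
xi = tangent_project(A, G)
A_new = qr_retract(A, -0.1 * xi)                  # one Riemannian gradient-descent step
print(np.allclose(A.T @ xi, 0.0))                 # tangent vectors are orthogonal to A
print(np.allclose(A_new.T @ A_new, np.eye(k)))    # retracted basis is orthonormal
```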

Run experiments

The commands below show how to run the numerical experiments from the paper.

  • The .sh scripts assume 8 GPUs are available. To run on CPUs instead, change the --gpu argument in these scripts to --gpu=-1.
  • Note: These experiments can take several hours to finish, since they cover many configurations of dimensions and sample sizes, each with multiple repetitions. To obtain results faster, reduce the number of configurations in the .sh scripts.
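How the experiment entry points interpret --gpu is not shown here. As a hypothetical sketch, assuming the integer is mapped to a PyTorch-style device string, the convention "-1 means CPU" could look like this (`device_from_gpu_flag` is an illustrative name, not a function from the repository):

```python
import argparse

def device_from_gpu_flag(gpu: int) -> str:
    """Map an integer GPU index to a PyTorch-style device string; negative means CPU."""
    return "cpu" if gpu < 0 else f"cuda:{gpu}"

parser = argparse.ArgumentParser()
parser.add_argument("--gpu", type=int, default=0, help="GPU index, or -1 for CPU")

args = parser.parse_args(["--gpu=-1"])
print(device_from_gpu_flag(args.gpu))                            # cpu
print(device_from_gpu_flag(parser.parse_args(["--gpu=3"]).gpu))  # cuda:3
```

Using the `--gpu=-1` form (with `=`) keeps argparse from treating `-1` as a separate option.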

To run:

  1. Install the GSVGD module:
pip install -e .
  2. Run the experiments (the full experiments can take several hours):
# Example 1: run the multivariate Gaussian experiment and generate plots
sh scripts/run_gaussian.sh

# Example 2: run the conditioned diffusion experiment and generate plots
sh scripts/run_diffusion.sh

Basic usage

# Arguments:
#   distribution: target distribution class with a `log_prob` method (up to an additive constant)
#   kernel: instance of the kernel class
#   manifold: instance of the Grassmann manifold class, used for the projector update
#   optimizer: instance of the optimizer class, used for the particle update
#
#   delta: step size for the projector update
#   T: initial temperature T0
#   X: initial particles
#   A: initial projectors
#   m: number of projectors
#   epochs: number of iterations
from src.gsvgd import FullGSVGDBatch  # import path assumed from src/gsvgd.py

# instantiate GSVGD class
gsvgd = FullGSVGDBatch(
    target=distribution,
    kernel=kernel,
    manifold=manifold,
    optimizer=optimizer,
    delta=delta,
    T=T
)

# update particles
_ = gsvgd.fit(X=X, A=A, m=m, epochs=epochs, threshold=0.0001*m)

# final particles: X (updates are done in-place)
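The snippet above assumes that `distribution`, `X`, `A` and the other inputs already exist. As a hedged illustration of what such inputs might look like, here is a toy Gaussian target exposing a `log_prob` method (defined only up to a constant), together with initial particles and projectors. The names `GaussianTarget` and `init_projectors`, the use of NumPy rather than the tensor type the repository likely expects, and the `(m, d, M)` layout of `A` with projection dimension `M` are all assumptions.

```python
import numpy as np

class GaussianTarget:
    """Hypothetical target N(mean, I) whose log-density is known only up to a constant."""

    def __init__(self, mean):
        self.mean = np.asarray(mean, dtype=float)

    def log_prob(self, X):
        # Unnormalised log-density for each row of X (shape n x d);
        # the -d/2 * log(2*pi) normalising constant is dropped.
        diff = X - self.mean
        return -0.5 * np.sum(diff ** 2, axis=1)

    def score(self, X):
        # Gradient of log_prob w.r.t. X; unaffected by the dropped constant.
        return -(X - self.mean)

def init_projectors(m, d, M, rng):
    """m random orthonormal d x M bases (one hypothetical layout for A)."""
    return np.stack([np.linalg.qr(rng.standard_normal((d, M)))[0] for _ in range(m)])

rng = np.random.default_rng(0)
n, d, m, M = 50, 10, 5, 2
target = GaussianTarget(mean=np.ones(d))
X = rng.standard_normal((n, d))           # initial particles
A = init_projectors(m, d, M, rng)         # initial projectors, shape (m, d, M)
print(target.log_prob(X).shape, A.shape)  # (50,) (5, 10, 2)
```

Because SVGD-type updates only use the gradient of the log-density, dropping the normalising constant in `log_prob` does not change the particle dynamics.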

Run tests

python -m pytest

Code directory

.
│
├───requirements.txt: Dependencies.
├───setup.py: Setup script.
├───data: Covertype data.
├───experiments: Main scripts for the 5 numerical experiments.
├───plots: folder to hold plots.
├───scripts: Shell scripts to run the experiments and generate plots.
├───src: Source files for implementing each sampling method, and util functions for experiments.
│   ├───Sliced_KSD_Clean: Utils for Sliced SVGD adapted from Wenbo Gong.
│   ├───blr.py: Utils for Bayesian logistic regression.
│   ├───diffusion.py: Utils for conditioned diffusion.
│   ├───gsvgd.py: GSVGD class (main).
│   ├───kernel.py: Kernel class.
│   ├───manifold.py: Class for optimisation on the Grassmann manifold, adapted from Pymanopt.
│   ├───metrics.py: Metric class for evaluation of results.
│   ├───s_svgd.py: S-SVGD class, adapted from Wenbo Gong.
│   ├───svgd.py: SVGD class.
│   └───utils.py: Other utils.
├───tests: Unit tests.
├───thumbnail: Thumbnail figure.
└───README.md
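The directory lists src/svgd.py as the plain SVGD baseline. For reference, here is a minimal NumPy sketch of one vanilla SVGD update with an RBF kernel and the median-heuristic bandwidth commonly used with SVGD. It is a generic illustration of the baseline algorithm, not the repository's SVGD class; `rbf_kernel`, `svgd_step` and `gaussian_score` are hypothetical names.

```python
import numpy as np

def rbf_kernel(X, h):
    """K[j, i] = exp(-||x_j - x_i||^2 / h) and its gradient with respect to x_j."""
    diff = X[:, None, :] - X[None, :, :]       # diff[j, i] = x_j - x_i, shape (n, n, d)
    sq = np.sum(diff ** 2, axis=-1)
    K = np.exp(-sq / h)
    grad_K = -2.0 / h * diff * K[:, :, None]   # d k(x_j, x_i) / d x_j
    return K, grad_K

def svgd_step(X, score, stepsize=0.1):
    """One vanilla SVGD update using the median-heuristic bandwidth."""
    n = X.shape[0]
    sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    h = np.median(sq) / np.log(n + 1) + 1e-8
    K, grad_K = rbf_kernel(X, h)
    # phi(x_i) = 1/n * sum_j [k(x_j, x_i) * score(x_j) + grad_{x_j} k(x_j, x_i)]
    phi = (K.T @ score(X) + grad_K.sum(axis=0)) / n
    return X + stepsize * phi

def gaussian_score(Z):
    """Score (gradient of log-density) of N(1, I)."""
    return -(Z - 1.0)

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 2)) * 0.5 - 3.0   # particles start far from the target
for _ in range(500):
    X = svgd_step(X, gaussian_score, stepsize=0.1)
print(X.mean(axis=0))  # the particle mean should end up close to [1, 1]
```

The first term in `phi` drives particles towards high-density regions, while the kernel-gradient term pushes them apart. GSVGD and S-SVGD modify this update by working with lower-dimensional projections of the particles.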